"""Implements the long-short term memory character model.
This version vectorizes over multiple examples, but each string
has a fixed length."""

from os.path import dirname, join

from rnn import build_dataset, concat_and_multiply, one_hot_to_string, sigmoid, string_to_one_hot

import autograd.numpy as np
import autograd.numpy.random as npr
from autograd import grad
from autograd.misc.optimizers import adam
from autograd.scipy.special import logsumexp


def init_lstm_params(input_size, state_size, output_size, param_scale=0.01, rs=None):
    """Create a dictionary of randomly initialized LSTM parameters.

    Args:
        input_size: Width of each input vector.
        state_size: Width of the hidden and cell state vectors.
        output_size: Width of the prediction output.
        param_scale: Multiplier applied to every standard-normal draw.
        rs: Random source exposing ``randn(*shape)`` (e.g. a ``RandomState``).
            When ``None``, a fresh ``npr.RandomState(0)`` is created for this
            call, so default initialization is identical on every call
            instead of depending on how many times the function was called
            before.

    Returns:
        A dict with the keys "init cells", "init hiddens", "change",
        "forget", "ingate", "outgate" and "predict", each mapped to an array
        of scaled normal draws.
    """
    if rs is None:
        # A default built in the signature would be shared by all calls and
        # advance its state each time; build it here instead.
        rs = npr.RandomState(0)

    def rp(*shape):
        return rs.randn(*shape) * param_scale

    # The extra "+ 1" row presumably holds a bias weight added by
    # concat_and_multiply -- TODO confirm against the rnn module.
    gate_shape = (input_size + state_size + 1, state_size)

    # Draw order matters for reproducibility: keep the keys in this order.
    return {
        "init cells": rp(1, state_size),
        "init hiddens": rp(1, state_size),
        "change": rp(*gate_shape),
        "forget": rp(*gate_shape),
        "ingate": rp(*gate_shape),
        "outgate": rp(*gate_shape),
        "predict": rp(state_size + 1, output_size),
    }


def lstm_predict(params, inputs):
    """Run the LSTM over ``inputs`` and return normalized log-probabilities.

    ``inputs`` is iterated along its first axis, one slice per time step, and
    its second axis gives the number of sequences processed together.

    Returns a list with one more entry than there are time steps: the first
    entry is computed from the initial hidden state alone, and each later
    entry from the hidden state after consuming one more time step.
    """

    def gate(name, squash, step_input, prev_hiddens):
        # Apply the named weight matrix to the (input, hidden) pair, then squash.
        return squash(concat_and_multiply(params[name], step_input, prev_hiddens))

    def advance(step_input, prev_hiddens, prev_cells):
        candidate = gate("change", np.tanh, step_input, prev_hiddens)
        keep = gate("forget", sigmoid, step_input, prev_hiddens)
        write = gate("ingate", sigmoid, step_input, prev_hiddens)
        expose = gate("outgate", sigmoid, step_input, prev_hiddens)
        next_cells = prev_cells * keep + write * candidate
        return expose * np.tanh(next_cells), next_cells

    def to_log_probs(state):
        scores = concat_and_multiply(params["predict"], state)
        # Subtracting the log-sum-exp over axis 1 normalizes each row.
        return scores - logsumexp(scores, axis=1, keepdims=True)

    # Replicate the single-row initial states once per sequence.
    batch = inputs.shape[1]
    hiddens = np.repeat(params["init hiddens"], batch, axis=0)
    cells = np.repeat(params["init cells"], batch, axis=0)

    log_probs = [to_log_probs(hiddens)]
    for step_input in inputs:  # One slice per time step.
        hiddens, cells = advance(step_input, hiddens, cells)
        log_probs.append(to_log_probs(hiddens))
    return log_probs


def lstm_log_likelihood(params, inputs, targets):
    """Return the average log-likelihood of ``targets`` under the model.

    For each time step ``t`` in ``range(inputs.shape[0])`` the log-probs
    produced by ``lstm_predict`` at index ``t`` are multiplied elementwise
    with ``targets[t]`` and summed. The grand total is divided by the number
    of time steps times the number of examples, both taken from
    ``inputs.shape``.
    """
    predictions = lstm_predict(params, inputs)
    num_time_steps, num_examples, _ = inputs.shape
    per_step = (np.sum(predictions[t] * targets[t]) for t in range(num_time_steps))
    # Start from 0.0 so the running total is a float even with no time steps.
    total = sum(per_step, 0.0)
    return total / (num_time_steps * num_examples)


if __name__ == "__main__":
    # Alphabet size: used for the dataset encoding and for the model's input
    # and output widths, so it is defined once here.
    num_chars = 128

    # Learn to predict our own source code.
    text_filename = join(dirname(__file__), "lstm.py")
    train_inputs = build_dataset(text_filename, sequence_length=30, alphabet_size=num_chars, max_lines=60)

    # Input and output widths follow num_chars rather than repeating the literal.
    init_params = init_lstm_params(
        input_size=num_chars, output_size=num_chars, state_size=40, param_scale=0.01
    )

    def print_training_prediction(weights):
        """Print each training sequence next to the model's decoded prediction."""
        print("Training text                         Predicted text")
        logprobs = np.asarray(lstm_predict(weights, train_inputs))
        # Axis 1 indexes the individual sequences.
        for seq_idx in range(logprobs.shape[1]):
            training_text = one_hot_to_string(train_inputs[:, seq_idx, :])
            predicted_text = one_hot_to_string(logprobs[:, seq_idx, :])
            print(training_text.replace("\n", " ") + "|" + predicted_text.replace("\n", " "))

    def training_loss(params, iteration):
        """Negative average log-likelihood of the training data predicting itself.

        ``iteration`` is accepted but not used.
        """
        return -lstm_log_likelihood(params, train_inputs, train_inputs)

    def callback(weights, iteration, gradient):
        """Report the loss and sample predictions on every tenth iteration."""
        if iteration % 10 == 0:
            print("Iteration", iteration, "Train loss:", training_loss(weights, 0))
            print_training_prediction(weights)

    # Build gradient of loss function using autograd.
    training_loss_grad = grad(training_loss)

    print("Training LSTM...")
    trained_params = adam(training_loss_grad, init_params, step_size=0.1, num_iters=1000, callback=callback)

    print()
    print("Generating text from LSTM...")
    num_letters = 30
    for _ in range(20):
        text = ""
        for _ in range(num_letters):
            # Re-encode everything generated so far as a single sequence and
            # sample the next character from the final step's distribution.
            seqs = string_to_one_hot(text, num_chars)[:, np.newaxis, :]
            logprobs = lstm_predict(trained_params, seqs)[-1].ravel()
            text += chr(npr.choice(len(logprobs), p=np.exp(logprobs)))
        print(text)
